import pytest
from django.test import RequestFactory
from unittest.mock import Mock, patch, MagicMock
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
from datetime import datetime
import threading
from collections import defaultdict

from ..views import SearchProductView


class TestSearchProductViewLoadTest:
    """Load test suite for SearchProductView with 500 concurrent users"""
    
    @pytest.fixture
    def view(self):
        """Provide the ``SearchProductView`` instance exercised by the tests."""
        search_view = SearchProductView()
        return search_view
    
    @pytest.fixture
    def factory(self):
        """Provide the Django ``RequestFactory`` used to build test requests."""
        request_factory = RequestFactory()
        return request_factory
    
    @pytest.fixture
    def mock_dependencies(self):
        """Patch the models, Celery tasks and serializer used by the view module.

        The patches target names in ``products.views`` and stay active until
        the requesting test has finished.

        Yields:
            dict: the active mocks, keyed by ``'product'``, ``'scraper'``,
            ``'jumia'``, ``'konga'``, ``'slot'`` and ``'serializer'``.
        """
        with patch('products.views.Product') as mock_product, \
             patch('products.views.ActiveScraper') as mock_scraper, \
             patch('products.views.scrape_jumia_task') as mock_jumia, \
             patch('products.views.scrape_konga_task') as mock_konga, \
             patch('products.views.scrape_slot_task') as mock_slot, \
             patch('products.views.ProductSerializer') as mock_serializer:

            # Setup mock Product queryset: every chaining method returns the
            # same mock so arbitrary filter/order_by chains keep working.
            mock_queryset = Mock()
            mock_queryset.filter.return_value = mock_queryset
            mock_queryset.order_by.return_value = mock_queryset
            mock_queryset.values_list.return_value = mock_queryset
            mock_queryset.distinct.return_value = ['jumia', 'konga']
            mock_queryset.count.return_value = 100

            # Mock paginated results
            mock_products = []
            for i in range(20):
                product = Mock(id=i, site="jumia")
                # ``name`` is a reserved Mock() constructor argument (it only
                # labels the mock's repr), so the attribute is assigned after
                # construction to actually hold the product name.
                product.name = f"Product {i}"
                mock_products.append(product)

            # side_effect hands out a fresh iterator on every iteration. A
            # single ``iter(...)`` return value would be exhausted by the first
            # request and leave all later (concurrent) requests with no rows.
            mock_queryset.__iter__ = Mock(side_effect=lambda: iter(mock_products))
            mock_queryset.__len__ = Mock(return_value=len(mock_products))

            mock_product.objects = mock_queryset

            # Setup mock ActiveScraper: reports no scraper currently running
            mock_active_queryset = Mock()
            mock_active_queryset.filter.return_value = mock_active_queryset
            mock_active_queryset.values_list.return_value = mock_active_queryset
            mock_active_queryset.distinct.return_value = []
            mock_scraper.objects = mock_active_queryset

            # Setup mock Celery tasks: every ``delay()`` returns the same task
            mock_task = Mock()
            mock_task.id = "test-task-id-123"
            mock_jumia.delay.return_value = mock_task
            mock_konga.delay.return_value = mock_task
            mock_slot.delay.return_value = mock_task

            # Setup mock serializer
            mock_serializer.return_value.data = [
                {"id": i, "name": f"Product {i}"} for i in range(20)
            ]

            yield {
                'product': mock_product,
                'scraper': mock_scraper,
                'jumia': mock_jumia,
                'konga': mock_konga,
                'slot': mock_slot,
                'serializer': mock_serializer
            }
    
    def make_request(self, factory, view, query="laptop", page=1, site=None):
        """Build a GET request for the search endpoint and pass it to *view*.

        Args:
            factory: request factory whose ``get(path, data)`` builds the
                request (the ``factory`` fixture supplies a ``RequestFactory``).
            view: object whose ``get(request)`` handles the request.
            query: search term, sent as the ``q`` parameter.
            page: page number, sent as the ``page`` parameter.
            site: optional site filter; sent as ``site`` only when truthy.

        Returns:
            Whatever ``view.get(request)`` returns.
        """
        params = {'q': query, 'page': page}
        if site:
            params['site'] = site

        # Passing the parameters as request data lets the factory URL-encode
        # them, so values containing spaces, ``&`` or ``#`` reach the view
        # intact instead of corrupting a hand-built query string.
        request = factory.get('/api/search/', params)
        return view.get(request)
    
    def test_single_request_baseline(self, view, factory, mock_dependencies):
        """Sanity check: one default request succeeds and has the expected keys."""
        response = self.make_request(factory, view)

        assert response.status_code == 200

        # Top-level keys every search response payload must expose
        for expected_key in ('results', 'scraping_status', 'sites_available'):
            assert expected_key in response.data
    
    def test_concurrent_same_query(self, view, factory, mock_dependencies):
        """Test 500 users searching for the same query.

        Every request goes through ``make_request`` on a thread pool sized to
        the number of users. Per-request latency, status codes and exception
        types are collected under a lock and printed as a summary. The test
        passes when more than 95% of requests return HTTP 200, the mean
        latency is below 1s and the 95th percentile is below 2s.
        """
        num_users = 500
        query = "laptop"
        # Report progress every 100 completions (the previous interval of 1000
        # could never be reached with 500 requests).
        progress_interval = 100

        results = {
            'success': 0,
            'errors': 0,
            'response_times': [],
            'error_details': defaultdict(int)
        }
        lock = threading.Lock()

        def execute_request(user_id):
            """Execute a single request and track results"""
            try:
                # perf_counter is monotonic and high resolution, which makes
                # it the right clock for measuring short intervals.
                start = time.perf_counter()
                response = self.make_request(factory, view, query=query)
                elapsed = time.perf_counter() - start

                with lock:
                    results['response_times'].append(elapsed)
                    if response.status_code == 200:
                        results['success'] += 1
                    else:
                        results['errors'] += 1
                        results['error_details'][response.status_code] += 1

                return {'user_id': user_id, 'status': response.status_code, 'time': elapsed}

            except Exception as e:
                # Count the failure by exception type so it shows up in the
                # error breakdown instead of aborting the worker.
                with lock:
                    results['errors'] += 1
                    results['error_details'][type(e).__name__] += 1
                return {'user_id': user_id, 'error': str(e)}

        # Execute concurrent requests
        separator = '=' * 60
        print(f"\n{separator}")
        print(f"Starting load test: {num_users} concurrent users")
        print(separator)

        start_time = time.perf_counter()

        with ThreadPoolExecutor(max_workers=num_users) as executor:
            futures = [executor.submit(execute_request, i) for i in range(num_users)]

            # Track progress
            for completed, _ in enumerate(as_completed(futures), start=1):
                if completed % progress_interval == 0:
                    print(f"Progress: {completed}/{num_users} requests completed")

        total_time = time.perf_counter() - start_time
        # Guard the division in case the clock reports no elapsed time
        requests_per_second = num_users / total_time if total_time > 0 else float('inf')

        # Calculate statistics (sorted once; reused for min/max/percentiles)
        sorted_times = sorted(results['response_times'])
        sample_count = len(sorted_times)

        def percentile(fraction):
            """Return the sample at *fraction* of the sorted latencies, or 0."""
            if not sorted_times:
                return 0
            return sorted_times[int(sample_count * fraction)]

        avg_response = sum(sorted_times) / sample_count if sorted_times else 0
        min_response = sorted_times[0] if sorted_times else 0
        max_response = sorted_times[-1] if sorted_times else 0
        p50 = percentile(0.5)
        p95 = percentile(0.95)
        p99 = percentile(0.99)

        # Print results
        print(f"\n{separator}")
        print("LOAD TEST RESULTS")
        print(separator)
        print(f"Total Requests:        {num_users}")
        print(f"Successful:            {results['success']} ({results['success']/num_users*100:.2f}%)")
        print(f"Failed:                {results['errors']} ({results['errors']/num_users*100:.2f}%)")
        print(f"Total Time:            {total_time:.2f}s")
        print(f"Requests/Second:       {requests_per_second:.2f}")
        print("\nResponse Times:")
        print(f"  Average:             {avg_response*1000:.2f}ms")
        print(f"  Min:                 {min_response*1000:.2f}ms")
        print(f"  Max:                 {max_response*1000:.2f}ms")
        print(f"  50th Percentile:     {p50*1000:.2f}ms")
        print(f"  95th Percentile:     {p95*1000:.2f}ms")
        print(f"  99th Percentile:     {p99*1000:.2f}ms")

        if results['error_details']:
            print("\nError Breakdown:")
            for error_type, count in results['error_details'].items():
                print(f"  {error_type}: {count}")

        print(f"{separator}\n")

        # Assertions
        assert results['success'] > num_users * 0.95, "Less than 95% success rate"
        assert avg_response < 1.0, f"Average response time too high: {avg_response}s"
        assert p95 < 2.0, f"95th percentile response time too high: {p95}s"
    
    def test_concurrent_varied_queries(self, view, factory, mock_dependencies):
        """Test 500 users with varied search queries.

        Request ``i`` searches for ``queries[i % len(queries)]`` on page
        ``(i % 5) + 1``. The test passes when more than 95% of the requests
        return HTTP 200; failures are tallied by status code or exception
        type so a failing run explains itself.
        """
        num_users = 500
        queries = ["laptop", "phone", "tablet", "headphones", "camera",
                   "keyboard", "mouse", "monitor", "speaker", "charger"]

        results = {'success': 0, 'errors': 0}
        error_details = defaultdict(int)
        lock = threading.Lock()

        def execute_request(user_id):
            """Run one request and record its outcome under the lock."""
            query = queries[user_id % len(queries)]
            page = (user_id % 5) + 1

            try:
                response = self.make_request(factory, view, query=query, page=page)
                with lock:
                    if response.status_code == 200:
                        results['success'] += 1
                    else:
                        results['errors'] += 1
                        error_details[response.status_code] += 1
            except Exception as exc:
                # Keep the exception type so the cause is not lost
                with lock:
                    results['errors'] += 1
                    error_details[type(exc).__name__] += 1

        print(f"\nTesting {num_users} users with varied queries...")
        # perf_counter is monotonic, so the measured interval cannot go
        # backwards or collapse the way wall-clock time can.
        start_time = time.perf_counter()

        with ThreadPoolExecutor(max_workers=num_users) as executor:
            for user_id in range(num_users):
                executor.submit(execute_request, user_id)
        # Leaving the ``with`` block waits for every submitted request.

        total_time = time.perf_counter() - start_time

        print(f"Completed in {total_time:.2f}s")
        print(f"Success: {results['success']}, Errors: {results['errors']}")
        if error_details:
            print(f"Error breakdown: {dict(error_details)}")

        assert results['success'] > num_users * 0.95, (
            f"Only {results['success']}/{num_users} requests succeeded; "
            f"errors: {dict(error_details)}"
        )
    
    def test_concurrent_with_site_filter(self, view, factory, mock_dependencies):
        """Test concurrent requests with site filtering.

        Request ``i`` uses ``sites[i % len(sites)]`` as its site filter, where
        ``None`` means no filter is sent. The test passes when more than 95%
        of the requests return HTTP 200; failures are tallied by status code
        or exception type so a failing run explains itself.
        """
        num_users = 500
        sites = ['jumia', 'konga', 'slot.ng', None]

        results = {'success': 0, 'errors': 0}
        error_details = defaultdict(int)
        lock = threading.Lock()

        def execute_request(user_id):
            """Run one filtered request and record its outcome under the lock."""
            site = sites[user_id % len(sites)]
            try:
                response = self.make_request(factory, view, query="phone", site=site)
                with lock:
                    if response.status_code == 200:
                        results['success'] += 1
                    else:
                        results['errors'] += 1
                        error_details[response.status_code] += 1
            except Exception as exc:
                # Keep the exception type so the cause is not lost
                with lock:
                    results['errors'] += 1
                    error_details[type(exc).__name__] += 1

        print(f"\nTesting {num_users} users with site filters...")
        # perf_counter is monotonic, so the measured interval cannot go
        # backwards or collapse the way wall-clock time can.
        start_time = time.perf_counter()

        with ThreadPoolExecutor(max_workers=num_users) as executor:
            for user_id in range(num_users):
                executor.submit(execute_request, user_id)
        # Leaving the ``with`` block waits for every submitted request.

        total_time = time.perf_counter() - start_time

        print(f"Completed in {total_time:.2f}s")
        print(f"Success: {results['success']}, Errors: {results['errors']}")
        if error_details:
            print(f"Error breakdown: {dict(error_details)}")

        assert results['success'] > num_users * 0.95, (
            f"Only {results['success']}/{num_users} requests succeeded; "
            f"errors: {dict(error_details)}"
        )
    
    def test_database_connection_pooling(self, view, factory, mock_dependencies):
        """Check that each concurrent request queries the (mocked) Product manager.

        The ORM is fully mocked by ``mock_dependencies``, so no real database
        connection is involved: the test wraps ``Product.objects.filter`` with
        a thread-safe counter and asserts it was called at least once per
        request. Exceptions raised by requests are tallied and reported rather
        than silently discarded.
        """
        num_users = 500

        # Track mock calls; the lock keeps the counters consistent across threads
        call_counts = {'filter': 0}
        failures = defaultdict(int)
        lock = threading.Lock()

        product_manager = mock_dependencies['product'].objects
        original_filter = product_manager.filter

        def tracked_filter(*args, **kwargs):
            """Count the call, then delegate to the original mocked filter."""
            with lock:
                call_counts['filter'] += 1
            return original_filter(*args, **kwargs)

        product_manager.filter = tracked_filter

        def execute_request(user_id):
            """Run one request; failures are recorded but do not stop the run."""
            try:
                self.make_request(factory, view, query="laptop")
            except Exception as exc:
                with lock:
                    failures[type(exc).__name__] += 1

        print(f"\nTesting database connection management with {num_users} users...")

        with ThreadPoolExecutor(max_workers=num_users) as executor:
            for user_id in range(num_users):
                executor.submit(execute_request, user_id)
        # Leaving the ``with`` block waits for every submitted request.

        print(f"Total filter calls: {call_counts['filter']}")
        if failures:
            print(f"Request failures: {dict(failures)}")

        # Should have one filter call per request (or more if including ActiveScraper checks)
        assert call_counts['filter'] >= num_users, (
            f"Expected at least {num_users} filter calls, got "
            f"{call_counts['filter']}; request failures: {dict(failures)}"
        )
    
    @pytest.mark.parametrize("num_users", [100, 200, 300, 400, 500])
    def test_scalability(self, view, factory, mock_dependencies, num_users):
        """Test scalability at different load levels.

        Runs ``num_users`` concurrent requests on a pool of at most 500
        workers, prints the elapsed time and throughput, and passes when more
        than 90% of the requests return HTTP 200. Exceptions are tallied by
        type so a failing run explains itself.
        """
        results = {'success': 0}
        failures = defaultdict(int)
        lock = threading.Lock()

        def execute_request(user_id):
            """Run one request and count it if it returned HTTP 200."""
            try:
                response = self.make_request(factory, view, query="test")
                if response.status_code == 200:
                    with lock:
                        results['success'] += 1
            except Exception as exc:
                # Record instead of discarding, so the cause is not lost
                with lock:
                    failures[type(exc).__name__] += 1

        # perf_counter is monotonic and high resolution; a coarse wall clock
        # can report zero elapsed time for a fast, fully mocked run.
        start_time = time.perf_counter()

        with ThreadPoolExecutor(max_workers=min(500, num_users)) as executor:
            for user_id in range(num_users):
                executor.submit(execute_request, user_id)
        # Leaving the ``with`` block waits for every submitted request.

        total_time = time.perf_counter() - start_time
        # Guard the division in case the clock still reports no elapsed time
        throughput = num_users / total_time if total_time > 0 else float('inf')

        print(f"\n{num_users} users: {total_time:.2f}s, {throughput:.2f} req/s")
        if failures:
            print(f"Request failures: {dict(failures)}")

        assert results['success'] > num_users * 0.90, (
            f"Only {results['success']}/{num_users} requests succeeded; "
            f"failures: {dict(failures)}"
        )